//===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements helper classes for implementing the "Op" types.  This
// includes the Op type, which is the base class for Op class definitions,
// as well as a number of traits in the OpTrait namespace that provide a
// declarative way to specify properties of Ops.
//
// The purpose of these types is to allow light-weight implementation of
// concrete ops (like DimOp) with very little boilerplate.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_OPDEFINITION_H
#define MLIR_IR_OPDEFINITION_H

#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/PointerLikeTypeTraits.h"

#include <type_traits>

namespace mlir {
class Builder;
class OpBuilder;

/// A tri-state parse result: "absent", "present and succeeded", or "present
/// and failed". It wraps an `Optional<ParseResult>` rather than exposing one
/// directly, because `Optional` provides an implicit conversion to 'bool' that
/// we want to avoid. It is used by 'parseOptional' style functions, whose
/// parse failures must not be attributed to "not present".
class OptionalParseResult {
public:
  /// Build the "absent" state.
  OptionalParseResult() = default;
  OptionalParseResult(llvm::NoneType) : impl(llvm::None) {}

  /// Build a "present" state holding the given result.
  OptionalParseResult(ParseResult result) : impl(result) {}
  OptionalParseResult(LogicalResult result) : impl(result) {}

  /// Build a "present" state holding a failure; the diagnostic itself is not
  /// inspected.
  OptionalParseResult(const InFlightDiagnostic &)
      : OptionalParseResult(failure()) {}

  /// Returns true if a ParseResult is present (whether success or failure).
  bool hasValue() const { return impl.has_value(); }

  /// Access the stored ParseResult.
  ParseResult getValue() const { return impl.value(); }
  ParseResult operator*() const { return impl.value(); }

private:
  /// The wrapped result; empty when nothing was parsed.
  Optional<ParseResult> impl;
};

// These functions are out-of-line utilities, which avoids them being template
// instantiated/duplicated.
namespace impl {
/// Insert an operation, generated by `buildTerminatorOp`, at the end of the
/// region's only block if it does not have a terminator already. If the region
/// is empty, insert a new block first. `buildTerminatorOp` should return the
/// terminator operation to insert.
void ensureRegionTerminator(
    Region &region, OpBuilder &builder, Location loc,
    function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
/// Overload of the above that accepts a plain `Builder`. The terminator
/// callback still receives an `OpBuilder`.
/// NOTE(review): how the `OpBuilder` handed to the callback is obtained from
/// `builder` is defined out of line -- confirm in the implementation file.
void ensureRegionTerminator(
    Region &region, Builder &builder, Location loc,
    function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);

} // namespace impl

/// This is the concrete base class that holds the operation pointer and has
/// non-generic methods that only depend on State (to avoid having them
/// instantiated on template types that don't affect them).
///
/// This also has the fallback implementations of customization hooks for when
/// they aren't customized.
class OpState {
public:
  /// Ops are pointer-like, so we allow conversion to bool.
  explicit operator bool() { return getOperation() != nullptr; }

  /// This implicitly converts to Operation*.
  operator Operation *() const { return state; }

  /// Shortcut of `->` to access a member of Operation.
  Operation *operator->() const { return state; }

  /// Return the operation that this refers to.
  Operation *getOperation() { return state; }

  /// Return the context this operation belongs to.
  MLIRContext *getContext() { return getOperation()->getContext(); }

  /// Print the operation to the given stream.
  void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
    state->print(os, flags);
  }
  void print(raw_ostream &os, AsmState &asmState) {
    state->print(os, asmState);
  }

  /// Dump this operation.
  void dump() { state->dump(); }

  /// The source location the operation was defined or derived from.
  Location getLoc() { return state->getLoc(); }

  /// Return true if there are no users of any results of this operation.
  bool use_empty() { return state->use_empty(); }

  /// Remove this operation from its parent block and delete it.
  void erase() { state->erase(); }

  /// Emit an error with the op name prefixed, like "'dim' op " which is
  /// convenient for verifiers.
  InFlightDiagnostic emitOpError(const Twine &message = {});

  /// Emit an error about fatal conditions with this operation, reporting up to
  /// any diagnostic handlers that may be listening.
  InFlightDiagnostic emitError(const Twine &message = {});

  /// Emit a warning about this operation, reporting up to any diagnostic
  /// handlers that may be listening.
  InFlightDiagnostic emitWarning(const Twine &message = {});

  /// Emit a remark about this operation, reporting up to any diagnostic
  /// handlers that may be listening.
  InFlightDiagnostic emitRemark(const Twine &message = {});

  /// Walk the operation by calling the callback for each nested operation
  /// (including this one), block or region, depending on the callback provided.
  /// Regions, blocks and operations at the same nesting level are visited in
  /// lexicographical order. The walk order for enclosing regions, blocks and
  /// operations with respect to their nested ones is specified by 'Order'
  /// (post-order by default). A callback on a block or operation is allowed to
  /// erase that block or operation if either:
  ///   * the walk is in post-order, or
  ///   * the walk is in pre-order and the walk is skipped after the erasure.
  /// See Operation::walk for more details.
  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
            typename RetT = detail::walkResultType<FnT>>
  typename std::enable_if<
      llvm::function_traits<std::decay_t<FnT>>::num_args == 1, RetT>::type
  walk(FnT &&callback) {
    return state->walk<Order>(std::forward<FnT>(callback));
  }

  /// Generic walker with a stage aware callback. Walk the operation by calling
  /// the callback for each nested operation (including this one) N+1 times,
  /// where N is the number of regions attached to that operation.
  ///
  /// The callback method can take any of the following forms:
  ///   void(Operation *, const WalkStage &) : Walk all operation opaquely
  ///     * op.walk([](Operation *nestedOp, const WalkStage &stage) { ...});
  ///   void(OpT, const WalkStage &) : Walk all operations of the given derived
  ///                                  type.
  ///     * op.walk([](ReturnOp returnOp, const WalkStage &stage) { ...});
  ///   WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations,
  ///          but allow for interruption/skipping.
  ///     * op.walk([](... op, const WalkStage &stage) {
  ///         // Skip the walk of this op based on some invariant.
  ///         if (some_invariant)
  ///           return WalkResult::skip();
  ///         // Interrupt, i.e cancel, the walk based on some invariant.
  ///         if (another_invariant)
  ///           return WalkResult::interrupt();
  ///         return WalkResult::advance();
  ///       });
  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
  typename std::enable_if<
      llvm::function_traits<std::decay_t<FnT>>::num_args == 2, RetT>::type
  walk(FnT &&callback) {
    return state->walk(std::forward<FnT>(callback));
  }

  // These are default implementations of customization hooks.
public:
  /// This hook returns any canonicalization pattern rewrites that the operation
  /// supports, for use by the canonicalization pass.
  static void getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {}

  /// This hook populates any unset default attrs.
  static void populateDefaultAttrs(const RegisteredOperationName &,
                                   NamedAttrList &) {}

protected:
  /// If the concrete type didn't implement a custom verifier hook, just fall
  /// back to this one which accepts everything.
  LogicalResult verify() { return success(); }
  LogicalResult verifyRegions() { return success(); }

  /// Parse the custom form of an operation. Unless overridden, this method will
  /// first try to get an operation parser from the op's dialect. Otherwise the
  /// custom assembly form of an op is always rejected. Op implementations
  /// should implement this to return failure. On success, they should fill in
  /// result with the fields to use.
  static ParseResult parse(OpAsmParser &parser, OperationState &result);

  /// Print the operation. Unless overridden, this method will first try to get
  /// an operation printer from the dialect. Otherwise, it prints the operation
  /// in generic form.
  static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect);

  /// Print an operation name, eliding the dialect prefix if necessary.
  static void printOpName(Operation *op, OpAsmPrinter &p,
                          StringRef defaultDialect);

  /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
  /// so we can cast it away here.
  explicit OpState(Operation *state) : state(state) {}

private:
  Operation *state;

  /// Allow access to internal hook implementation methods.
  friend RegisteredOperationName;
};

// Two op handles compare equal when they wrap the same Operation pointer.
inline bool operator==(OpState lhs, OpState rhs) {
  Operation *lhsOp = lhs.getOperation();
  Operation *rhsOp = rhs.getOperation();
  return lhsOp == rhsOp;
}
// Two op handles compare unequal when they wrap different Operation pointers.
inline bool operator!=(OpState lhs, OpState rhs) {
  Operation *lhsOp = lhs.getOperation();
  Operation *rhsOp = rhs.getOperation();
  return lhsOp != rhsOp;
}

/// Stream printer for OpFoldResult. Declared ahead of the class so that
/// `OpFoldResult::dump()` can use it; the definition follows the class.
raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr);

/// The result of folding a single operation result: a pointer union holding
/// either an Attribute or a Value.
class OpFoldResult : public PointerUnion<Attribute, Value> {
  using PointerUnion<Attribute, Value>::PointerUnion;

public:
  /// Print this fold result, followed by a newline, to the LLVM error stream.
  void dump() {
    raw_ostream &os = llvm::errs();
    os << *this << "\n";
  }
};

/// Print a fold result to `os`: the held Value when there is one, and the
/// Attribute alternative otherwise.
inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) {
  Value value = ofr.dyn_cast<Value>();
  if (!value) {
    ofr.dyn_cast<Attribute>().print(os);
    return os;
  }
  value.print(os);
  return os;
}

/// Print an op to `os` using printing flags with local scope enabled.
inline raw_ostream &operator<<(raw_ostream &os, OpState op) {
  OpPrintingFlags flags = OpPrintingFlags().useLocalScope();
  op.print(os, flags);
  return os;
}

//===----------------------------------------------------------------------===//
// Operation Trait Types
//===----------------------------------------------------------------------===//

namespace OpTrait {

// These functions are out-of-line implementations of the methods in the
// corresponding trait classes.  This avoids them being template
// instantiated/duplicated.
namespace impl {
// Folding helpers.
OpFoldResult foldIdempotent(Operation *op);
OpFoldResult foldInvolution(Operation *op);

// Operand verifiers.
LogicalResult verifyZeroOperands(Operation *op);
LogicalResult verifyOneOperand(Operation *op);
LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
LogicalResult verifyIsIdempotent(Operation *op);
LogicalResult verifyIsInvolution(Operation *op);
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
LogicalResult verifyOperandsAreFloatLike(Operation *op);
LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
LogicalResult verifySameTypeOperands(Operation *op);

// Region verifiers.
LogicalResult verifyZeroRegions(Operation *op);
LogicalResult verifyOneRegion(Operation *op);
LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);

// Result verifiers. The count parameter is a number of results (the same
// value the NResults/AtLeastNResults traits pass as `N`).
LogicalResult verifyZeroResults(Operation *op);
LogicalResult verifyOneResult(Operation *op);
LogicalResult verifyNResults(Operation *op, unsigned numResults);
LogicalResult verifyAtLeastNResults(Operation *op, unsigned numResults);

// Shape / element type / type agreement verifiers.
LogicalResult verifySameOperandsShape(Operation *op);
LogicalResult verifySameOperandsAndResultShape(Operation *op);
LogicalResult verifySameOperandsElementType(Operation *op);
LogicalResult verifySameOperandsAndResultElementType(Operation *op);
LogicalResult verifySameOperandsAndResultType(Operation *op);
LogicalResult verifyResultsAreBoolLike(Operation *op);
LogicalResult verifyResultsAreFloatLike(Operation *op);
LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);

// Terminator and successor verifiers.
LogicalResult verifyIsTerminator(Operation *op);
LogicalResult verifyZeroSuccessors(Operation *op);
LogicalResult verifyOneSuccessor(Operation *op);
LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);

// Size attribute verifiers.
LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
                                  StringRef valueGroupName,
                                  size_t expectedCount);
LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);

// Miscellaneous verifiers.
LogicalResult verifyNoRegionArguments(Operation *op);
LogicalResult verifyElementwise(Operation *op);
LogicalResult verifyIsIsolatedFromAbove(Operation *op);
} // namespace impl

/// Common base of all trait classes. Clients are not expected to interact with
/// it directly, so its members are all protected.
template <typename ConcreteType, template <typename> class TraitType>
class TraitBase {
protected:
  /// Return the Operation being worked on, obtained by downcasting `this` to
  /// the concrete op type and asking it.
  Operation *getOperation() {
    return static_cast<ConcreteType *>(this)->getOperation();
  }
};

//===----------------------------------------------------------------------===//
// Operand Traits

namespace detail {
/// Shared base for operand traits of ops that may have several operands. Every
/// accessor forwards to the method of the same name on the underlying
/// Operation.
template <typename ConcreteType, template <typename> class TraitType>
struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
  using operand_iterator = Operation::operand_iterator;
  using operand_range = Operation::operand_range;
  using operand_type_iterator = Operation::operand_type_iterator;
  using operand_type_range = Operation::operand_type_range;

  /// Number of operands of the operation.
  unsigned getNumOperands() {
    Operation *op = this->getOperation();
    return op->getNumOperands();
  }

  /// The operand at index 'i'.
  Value getOperand(unsigned i) {
    Operation *op = this->getOperation();
    return op->getOperand(i);
  }

  /// Replace the operand at index 'i' with 'value'.
  void setOperand(unsigned i, Value value) {
    Operation *op = this->getOperation();
    op->setOperand(i, value);
  }

  /// Iteration over the operands.
  operand_iterator operand_begin() {
    Operation *op = this->getOperation();
    return op->operand_begin();
  }
  operand_iterator operand_end() {
    Operation *op = this->getOperation();
    return op->operand_end();
  }
  operand_range getOperands() {
    Operation *op = this->getOperation();
    return op->getOperands();
  }

  /// Iteration over the operand types.
  operand_type_iterator operand_type_begin() {
    Operation *op = this->getOperation();
    return op->operand_type_begin();
  }
  operand_type_iterator operand_type_end() {
    Operation *op = this->getOperation();
    return op->operand_type_end();
  }
  operand_type_range getOperandTypes() {
    Operation *op = this->getOperation();
    return op->getOperandTypes();
  }
};
} // namespace detail

/// `verifyInvariantsImpl` verifies invariants such as the types, attrs, etc.
/// It should run after the core traits and before any other user defined
/// traits. Wrapping it in the OpInvariants trait lets tblgen place it at the
/// right position in the trait list.
template <typename ConcreteType>
class OpInvariants : public TraitBase<ConcreteType, OpInvariants> {
public:
  /// Cast `op` to the concrete op type and run its `verifyInvariantsImpl`.
  static LogicalResult verifyTrait(Operation *op) {
    auto concreteOp = cast<ConcreteType>(op);
    return concreteOp.verifyInvariantsImpl();
  }
};

/// Trait for ops that are known to have no SSA operand.
template <typename ConcreteType>
class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
  // Operand accessors are disabled for this trait: these are private,
  // argument-less no-ops.
  void getOperand() {}
  void setOperand() {}

public:
  /// Delegates verification to the out-of-line `impl::verifyZeroOperands`.
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifyZeroOperands(op);
    return status;
  }
};

/// Trait for ops that are known to have exactly one SSA operand. It exposes
/// index-free accessors that always address operand 0.
template <typename ConcreteType>
class OneOperand : public TraitBase<ConcreteType, OneOperand> {
public:
  /// Delegates verification to the out-of-line `impl::verifyOneOperand`.
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifyOneOperand(op);
    return status;
  }

  /// Return operand 0 of the operation.
  Value getOperand() {
    Operation *op = this->getOperation();
    return op->getOperand(0);
  }

  /// Set operand 0 of the operation to `value`.
  void setOperand(Value value) {
    Operation *op = this->getOperation();
    op->setOperand(0, value);
  }
};

/// Trait for ops that are known to have exactly `N` operands, with `N` >= 2.
/// The trait proper is the nested `Impl` template, used like this:
///
///   class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
///
template <unsigned N>
class NOperands {
public:
  template <typename ConcreteType>
  class Impl
      : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
  public:
    /// Delegates verification to `impl::verifyNOperands` with count `N`.
    static LogicalResult verifyTrait(Operation *op) {
      LogicalResult status = impl::verifyNOperands(op, N);
      return status;
    }
  };

  // Counts below two have dedicated traits.
  static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
};

/// Trait for ops that are known to have at least `N` operands. The trait
/// proper is the nested `Impl` template, used like this:
///
///   class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
///
template <unsigned N>
class AtLeastNOperands {
public:
  template <typename ConcreteType>
  class Impl : public detail::MultiOperandTraitBase<ConcreteType,
                                                    AtLeastNOperands<N>::Impl> {
  public:
    /// Delegates verification to `impl::verifyAtLeastNOperands` with minimum
    /// count `N`.
    static LogicalResult verifyTrait(Operation *op) {
      LogicalResult status = impl::verifyAtLeastNOperands(op, N);
      return status;
    }
  };
};

/// This class provides the API for ops which have an unknown number of
/// SSA operands. It adds nothing of its own: the operand accessors are
/// inherited from detail::MultiOperandTraitBase, and no `verifyTrait` is
/// defined here.
template <typename ConcreteType>
class VariadicOperands
    : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};

//===----------------------------------------------------------------------===//
// Region Traits

/// Trait for ops that are known to have zero regions. It only contributes a
/// verifier.
template <typename ConcreteType>
class ZeroRegions : public TraitBase<ConcreteType, ZeroRegions> {
public:
  /// Delegates verification to the out-of-line `impl::verifyZeroRegions`.
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifyZeroRegions(op);
    return status;
  }
};

namespace detail {
/// Utility trait base that provides accessors for derived traits that have
/// multiple regions.
template <typename ConcreteType, template <typename> class TraitType>
struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> {
  /// Iterator over the regions of the operation. This is the iterator type of
  /// the region list rather than the list type itself, so that
  /// `region_begin()`/`region_end()` denote positions, in line with the
  /// operand, result and successor trait bases.
  using region_iterator = MutableArrayRef<Region>::iterator;
  using region_range = RegionRange;

  /// Return the number of regions.
  unsigned getNumRegions() { return this->getOperation()->getNumRegions(); }

  /// Return the region at index `i`.
  Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); }

  /// Region iterator access. Both ends are taken from the operation's region
  /// list, which is a view into storage owned by the operation, so the
  /// iterators stay valid after the temporary view is gone.
  region_iterator region_begin() {
    return this->getOperation()->getRegions().begin();
  }
  region_iterator region_end() {
    return this->getOperation()->getRegions().end();
  }

  /// Return the regions of the operation as a range.
  region_range getRegions() { return this->getOperation()->getRegions(); }
};
} // namespace detail

/// Trait for ops that are known to have a single region. It exposes index-free
/// accessors that always address region 0.
template <typename ConcreteType>
class OneRegion : public TraitBase<ConcreteType, OneRegion> {
public:
  /// Delegates verification to the out-of-line `impl::verifyOneRegion`.
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifyOneRegion(op);
    return status;
  }

  /// Return region 0 of the operation.
  Region &getRegion() {
    Operation *op = this->getOperation();
    return op->getRegion(0);
  }

  /// Returns a range of operations within the region of this operation.
  auto getOps() {
    Region &body = getRegion();
    return body.getOps();
  }

  /// Same as above, restricted to operations of type `OpT`.
  template <typename OpT>
  auto getOps() {
    Region &body = getRegion();
    return body.template getOps<OpT>();
  }
};

/// Trait for ops that are known to have exactly `N` regions, with `N` >= 2.
/// The trait proper is the nested `Impl` template.
template <unsigned N>
class NRegions {
public:
  template <typename ConcreteType>
  class Impl
      : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
  public:
    /// Delegates verification to `impl::verifyNRegions` with count `N`.
    static LogicalResult verifyTrait(Operation *op) {
      LogicalResult status = impl::verifyNRegions(op, N);
      return status;
    }
  };

  // Counts below two have dedicated traits.
  static_assert(N > 1, "use ZeroRegions/OneRegion for N < 2");
};

/// Trait for ops that are known to have at least `N` regions. The trait proper
/// is the nested `Impl` template.
template <unsigned N>
class AtLeastNRegions {
public:
  template <typename ConcreteType>
  class Impl : public detail::MultiRegionTraitBase<ConcreteType,
                                                   AtLeastNRegions<N>::Impl> {
  public:
    /// Delegates verification to `impl::verifyAtLeastNRegions` with minimum
    /// count `N`.
    static LogicalResult verifyTrait(Operation *op) {
      LogicalResult status = impl::verifyAtLeastNRegions(op, N);
      return status;
    }
  };
};

/// This class provides the API for ops which have an unknown number of
/// regions. It adds nothing of its own: the region accessors are inherited
/// from detail::MultiRegionTraitBase, and no `verifyTrait` is defined here.
template <typename ConcreteType>
class VariadicRegions
    : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {};

//===----------------------------------------------------------------------===//
// Result Traits

/// Trait for ops that are known to have zero results. It only contributes a
/// verifier.
template <typename ConcreteType>
class ZeroResults : public TraitBase<ConcreteType, ZeroResults> {
public:
  /// Delegates verification to the out-of-line `impl::verifyZeroResults`.
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifyZeroResults(op);
    return status;
  }
};

namespace detail {
/// Shared base for result traits of ops that may have several results. Apart
/// from `getType`, every accessor forwards to the method of the same name on
/// the underlying Operation.
template <typename ConcreteType, template <typename> class TraitType>
struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
  using result_iterator = Operation::result_iterator;
  using result_range = Operation::result_range;
  using result_type_iterator = Operation::result_type_iterator;
  using result_type_range = Operation::result_type_range;

  /// Number of results of the operation.
  unsigned getNumResults() {
    Operation *op = this->getOperation();
    return op->getNumResults();
  }

  /// The result at index 'i'.
  Value getResult(unsigned i) {
    Operation *op = this->getOperation();
    return op->getResult(i);
  }

  /// The type of the `i`-th result.
  Type getType(unsigned i) {
    Value result = getResult(i);
    return result.getType();
  }

  /// Replace all uses of results of this operation with the provided 'values'.
  /// 'values' may correspond to an existing operation, or a range of 'Value'.
  template <typename ValuesT>
  void replaceAllUsesWith(ValuesT &&values) {
    Operation *op = this->getOperation();
    op->replaceAllUsesWith(std::forward<ValuesT>(values));
  }

  /// Iteration over the results.
  result_iterator result_begin() {
    Operation *op = this->getOperation();
    return op->result_begin();
  }
  result_iterator result_end() {
    Operation *op = this->getOperation();
    return op->result_end();
  }
  result_range getResults() {
    Operation *op = this->getOperation();
    return op->getResults();
  }

  /// Iteration over the result types.
  result_type_iterator result_type_begin() {
    Operation *op = this->getOperation();
    return op->result_type_begin();
  }
  result_type_iterator result_type_end() {
    Operation *op = this->getOperation();
    return op->result_type_end();
  }
  result_type_range getResultTypes() {
    Operation *op = this->getOperation();
    return op->getResultTypes();
  }
};
} // namespace detail

/// Trait for ops that are known to have a single result. It exposes index-free
/// accessors that always address result 0.
template <typename ConcreteType>
class OneResult : public TraitBase<ConcreteType, OneResult> {
public:
  /// Delegates verification to the out-of-line `impl::verifyOneResult`.
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifyOneResult(op);
    return status;
  }

  /// Return result 0 of the operation.
  Value getResult() {
    Operation *op = this->getOperation();
    return op->getResult(0);
  }

  /// Since the operation returns a single value, the Op can be implicitly
  /// converted to a Value. This yields the value of the only result.
  operator Value() { return getResult(); }

  /// Replace all uses of 'this' value with the new value, updating anything
  /// in the IR that uses 'this' to use the other value instead.  When this
  /// returns there are zero uses of 'this'.
  void replaceAllUsesWith(Value newValue) {
    Value result = getResult();
    result.replaceAllUsesWith(newValue);
  }

  /// Replace all uses of 'this' value with the result of 'op'.
  void replaceAllUsesWith(Operation *op) {
    Operation *self = this->getOperation();
    self->replaceAllUsesWith(op);
  }
};

/// Trait for single-result ops whose result is known to have a specific type
/// other than the generic `Type`, so that "getType()" can return that more
/// specific type. It should be used together with OneResult and appear before
/// OneResult in the trait list.
template <typename ResultType>
class OneTypedResult {
public:
  /// The trait proper. `ResultType` is the concrete type returned by
  /// getType().
  template <typename ConcreteType>
  class Impl
      : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
  public:
    /// Return the type of result 0, cast to `ResultType`.
    ResultType getType() {
      auto result = this->getOperation()->getResult(0);
      return result.getType().template cast<ResultType>();
    }
  };
};

/// Trait for ops that are known to have exactly `N` results, with `N` >= 2.
/// The trait proper is the nested `Impl` template, used like this:
///
///   class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
///
template <unsigned N>
class NResults {
public:
  template <typename ConcreteType>
  class Impl
      : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
  public:
    /// Delegates verification to `impl::verifyNResults` with count `N`.
    static LogicalResult verifyTrait(Operation *op) {
      LogicalResult status = impl::verifyNResults(op, N);
      return status;
    }
  };

  // Counts below two have dedicated traits.
  static_assert(N > 1, "use ZeroResults/OneResult for N < 2");
};

/// Trait for ops that are known to have at least `N` results. The trait proper
/// is the nested `Impl` template, used like this:
///
///   class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
///
template <unsigned N>
class AtLeastNResults {
public:
  template <typename ConcreteType>
  class Impl : public detail::MultiResultTraitBase<ConcreteType,
                                                   AtLeastNResults<N>::Impl> {
  public:
    /// Delegates verification to `impl::verifyAtLeastNResults` with minimum
    /// count `N`.
    static LogicalResult verifyTrait(Operation *op) {
      LogicalResult status = impl::verifyAtLeastNResults(op, N);
      return status;
    }
  };
};

/// This class provides the API for ops which have an unknown number of
/// results. It adds nothing of its own: the result accessors are inherited
/// from detail::MultiResultTraitBase, and no `verifyTrait` is defined here.
template <typename ConcreteType>
class VariadicResults
    : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};

//===----------------------------------------------------------------------===//
// Terminator Traits

/// This class indicates that the regions associated with this op don't have
/// terminators. It is a pure marker: it declares no members and no
/// `verifyTrait`, and is meant to be queried through `hasTrait<NoTerminator>()`.
template <typename ConcreteType>
class NoTerminator : public TraitBase<ConcreteType, NoTerminator> {};

/// Trait for ops that are known to be terminators. It only contributes a
/// verifier.
template <typename ConcreteType>
class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
public:
  /// Delegates verification to the out-of-line `impl::verifyIsTerminator`.
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifyIsTerminator(op);
    return status;
  }
};

/// Trait for ops that are known to have zero successors. It only contributes a
/// verifier.
template <typename ConcreteType>
class ZeroSuccessors : public TraitBase<ConcreteType, ZeroSuccessors> {
public:
  /// Delegates verification to the out-of-line `impl::verifyZeroSuccessors`.
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifyZeroSuccessors(op);
    return status;
  }
};

namespace detail {
/// Shared base for successor traits of ops that may have several successors.
/// Every accessor forwards to the method of the same name on the underlying
/// Operation.
template <typename ConcreteType, template <typename> class TraitType>
struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
  using succ_iterator = Operation::succ_iterator;
  using succ_range = SuccessorRange;

  /// Number of successors of the operation.
  unsigned getNumSuccessors() {
    Operation *op = this->getOperation();
    return op->getNumSuccessors();
  }

  /// The successor at index `i`.
  Block *getSuccessor(unsigned i) {
    Operation *op = this->getOperation();
    return op->getSuccessor(i);
  }

  /// Set the successor at index `i` to `block`.
  void setSuccessor(Block *block, unsigned i) {
    Operation *op = this->getOperation();
    op->setSuccessor(block, i);
  }

  /// Iteration over the successors.
  succ_iterator succ_begin() {
    Operation *op = this->getOperation();
    return op->succ_begin();
  }
  succ_iterator succ_end() {
    Operation *op = this->getOperation();
    return op->succ_end();
  }
  succ_range getSuccessors() {
    Operation *op = this->getOperation();
    return op->getSuccessors();
  }
};
} // namespace detail

/// Trait for ops that are known to have a single successor. It exposes
/// index-free accessors that always address successor 0.
template <typename ConcreteType>
class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
public:
  /// Delegates verification to the out-of-line `impl::verifyOneSuccessor`.
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifyOneSuccessor(op);
    return status;
  }

  /// Return successor 0 of the operation.
  Block *getSuccessor() {
    Operation *op = this->getOperation();
    return op->getSuccessor(0);
  }

  /// Set successor 0 of the operation to `succ`.
  void setSuccessor(Block *succ) {
    Operation *op = this->getOperation();
    op->setSuccessor(succ, 0);
  }
};

/// Trait for ops that are known to have exactly `N` successors, with `N` >= 2.
/// The trait proper is the nested `Impl` template.
template <unsigned N>
class NSuccessors {
public:
  template <typename ConcreteType>
  class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
                                                      NSuccessors<N>::Impl> {
  public:
    /// Delegates verification to `impl::verifyNSuccessors` with count `N`.
    static LogicalResult verifyTrait(Operation *op) {
      LogicalResult status = impl::verifyNSuccessors(op, N);
      return status;
    }
  };

  // Counts below two have dedicated traits.
  static_assert(N > 1, "use ZeroSuccessors/OneSuccessor for N < 2");
};

/// Trait for ops that are known to have at least `N` successors. The trait
/// proper is the nested `Impl` template.
template <unsigned N>
class AtLeastNSuccessors {
public:
  template <typename ConcreteType>
  class Impl
      : public detail::MultiSuccessorTraitBase<ConcreteType,
                                               AtLeastNSuccessors<N>::Impl> {
  public:
    /// Delegates verification to `impl::verifyAtLeastNSuccessors` with
    /// minimum count `N`.
    static LogicalResult verifyTrait(Operation *op) {
      LogicalResult status = impl::verifyAtLeastNSuccessors(op, N);
      return status;
    }
  };
};

/// This class provides the API for ops which have an unknown number of
/// successors. It adds nothing of its own: the successor accessors are
/// inherited from detail::MultiSuccessorTraitBase, and no `verifyTrait` is
/// defined here.
template <typename ConcreteType>
class VariadicSuccessors
    : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
};

//===----------------------------------------------------------------------===//
// SingleBlock

/// Trait for ops whose regions each hold at most one block. It provides the
/// matching verifier plus convenience accessors for the body block(s).
template <typename ConcreteType>
struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> {
  /// Check every region of `op`: an empty region is accepted, a non-empty one
  /// must consist of exactly one block and, unless the concrete op has the
  /// NoTerminator trait, that block must itself be non-empty.
  static LogicalResult verifyTrait(Operation *op) {
    const unsigned numRegions = op->getNumRegions();
    for (unsigned idx = 0; idx != numRegions; ++idx) {
      Region &region = op->getRegion(idx);

      // Nothing to check for a region without blocks.
      if (region.empty())
        continue;

      // More than one block is an error.
      if (!llvm::hasSingleElement(region))
        return op->emitOpError("expects region #")
               << idx << " to have 0 or 1 blocks";

      // Ops without terminators are allowed to have an empty block.
      if (ConcreteType::template hasTrait<NoTerminator>())
        continue;

      if (region.front().empty())
        return op->emitOpError() << "expects a non-empty block";
    }
    return success();
  }

  /// Return the region at `idx` (region 0 by default).
  Region &getBodyRegion(unsigned idx = 0) {
    return this->getOperation()->getRegion(idx);
  }

  /// Return the single block of the region at `idx` (region 0 by default).
  /// The region is asserted to be non-empty.
  Block *getBody(unsigned idx = 0) {
    Region &region = getBodyRegion(idx);
    assert(!region.empty() && "unexpected empty region");
    return &region.front();
  }

  //===------------------------------------------------------------------===//
  // Single Region Utilities
  //===------------------------------------------------------------------===//

  /// The methods below are only enabled when the parent operation has a
  /// single region. Each takes an extra template parameter standing for the
  /// concrete operation, so that SFINAE can disable the method for ops that
  /// do not have the OneRegion trait.
  template <typename OpT, typename T = void>
  using enable_if_single_region =
      std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;

  /// Iterators over the operations of the body block.
  template <typename OpT = ConcreteType>
  enable_if_single_region<OpT, Block::iterator> begin() {
    Block *body = getBody();
    return body->begin();
  }
  template <typename OpT = ConcreteType>
  enable_if_single_region<OpT, Block::iterator> end() {
    Block *body = getBody();
    return body->end();
  }

  /// The first operation of the body block.
  template <typename OpT = ConcreteType>
  enable_if_single_region<OpT, Operation &> front() {
    return *begin();
  }

  /// Insert the operation into the back of the body.
  template <typename OpT = ConcreteType>
  enable_if_single_region<OpT> push_back(Operation *op) {
    Block::iterator endIt = getBody()->end();
    insert(endIt, op);
  }

  /// Insert the operation at the given insertion point, given either as an
  /// operation or as a block iterator.
  template <typename OpT = ConcreteType>
  enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
    Block::iterator insertIt(insertPt);
    insert(insertIt, op);
  }
  template <typename OpT = ConcreteType>
  enable_if_single_region<OpT> insert(Block::iterator insertPt, Operation *op) {
    Block *body = getBody();
    body->getOperations().insert(insertPt, op);
  }
};

//===----------------------------------------------------------------------===//
// SingleBlockImplicitTerminator

/// Trait for ops whose regions hold a single block that must end with a
/// `TerminatorOpType` operation. It supplies the region verifier, helpers to
/// create the terminator, and insertion utilities that keep new operations in
/// front of the terminator.
template <typename TerminatorOpType>
struct SingleBlockImplicitTerminator {
  template <typename ConcreteType>
  class Impl : public SingleBlock<ConcreteType> {
  private:
    using Base = SingleBlock<ConcreteType>;

    /// Create a `TerminatorOpType` at `loc` straight from an OperationState.
    /// The OpBuilder creation APIs are not used, to avoid cyclic header
    /// inclusion.
    static Operation *buildTerminator(OpBuilder &builder, Location loc) {
      OperationState opState(loc, TerminatorOpType::getOperationName());
      TerminatorOpType::build(builder, opState);
      return Operation::create(opState);
    }

  public:
    /// The operation type used as the implicit terminator.
    using ImplicitTerminatorOpT = TerminatorOpType;

    /// Run the `SingleBlock` verifier, then check that the first block of
    /// every non-empty region ends with a `TerminatorOpType`.
    static LogicalResult verifyRegionTrait(Operation *op) {
      if (failed(Base::verifyTrait(op)))
        return failure();

      const unsigned numRegions = op->getNumRegions();
      for (unsigned regionIdx = 0; regionIdx != numRegions; ++regionIdx) {
        Region &region = op->getRegion(regionIdx);
        // A region without any block has nothing to check.
        if (region.empty())
          continue;

        Operation &lastOp = region.front().back();
        if (!isa<TerminatorOpType>(lastOp))
          return op->emitOpError("expects regions to end with '" +
                                 TerminatorOpType::getOperationName() +
                                 "', found '" +
                                 lastOp.getName().getStringRef() + "'")
                     .attachNote()
                 << "in custom textual format, the absence of terminator "
                    "implies '"
                 << TerminatorOpType::getOperationName() << '\'';
      }

      return success();
    }

    /// Ensure that `region` has the terminator required by this trait.
    /// This overload takes a plain Builder; an OpBuilder with no listeners is
    /// constructed locally. Use it only when no OpBuilder is available at the
    /// call site, e.g., in the parser.
    static void ensureTerminator(Region &region, Builder &builder,
                                 Location loc) {
      ::mlir::impl::ensureRegionTerminator(region, builder, loc,
                                           buildTerminator);
    }
    /// Ensure that `region` has the terminator required by this trait, using
    /// the given OpBuilder to build it so that the OpBuilder listeners are
    /// notified accordingly.
    static void ensureTerminator(Region &region, OpBuilder &builder,
                                 Location loc) {
      ::mlir::impl::ensureRegionTerminator(region, builder, loc,
                                           buildTerminator);
    }

    //===------------------------------------------------------------------===//
    // Single Region Utilities
    //===------------------------------------------------------------------===//
    using Base::getBody;

    /// SFINAE helper: yields `T` only when `OpT` has the `OneRegion` trait.
    template <typename OpT, typename T = void>
    using enable_if_single_region =
        std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;

    /// Append `op` to the body, placing it right before the terminator.
    template <typename OpT = ConcreteType>
    enable_if_single_region<OpT> push_back(Operation *op) {
      Block::iterator beforeTerminator(getBody()->getTerminator());
      insert(beforeTerminator, op);
    }

    /// Insert `op` at the position of the operation `insertPt`.
    template <typename OpT = ConcreteType>
    enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
      Block::iterator position(insertPt);
      insert(position, op);
    }
    /// Insert `op` at the iterator `insertPt`. Note: the operation is never
    /// inserted after the terminator; an insertion point equal to end() is
    /// redirected to the terminator.
    template <typename OpT = ConcreteType>
    enable_if_single_region<OpT> insert(Block::iterator insertPt,
                                        Operation *op) {
      Block *body = getBody();
      const bool atBlockEnd = insertPt == body->end();
      body->getOperations().insert(
          atBlockEnd ? Block::iterator(body->getTerminator()) : insertPt, op);
    }
  };
};

/// Check if an op defines the `ImplicitTerminatorOpT` member type. This alias
/// is only well-formed when `T::ImplicitTerminatorOpT` names a type, and it is
/// intended to be used with `llvm::is_detected`.
template <class T>
using has_implicit_terminator_t = typename T::ImplicitTerminatorOpT;

/// Compile-time check for whether `Op` has the SingleBlockImplicitTerminator
/// trait. `hasTrait` cannot be used directly because the trait is templated
/// on a specific terminator op; instead the terminator type is recovered from
/// `Op::ImplicitTerminatorOpT` when that member exists.
template <class Op, bool hasTerminator =
                        llvm::is_detected<has_implicit_terminator_t, Op>::value>
struct hasSingleBlockImplicitTerminator {
private:
  /// The trait instantiation `Op` must derive from for the check to hold.
  using ExpectedTrait = typename OpTrait::SingleBlockImplicitTerminator<
      typename Op::ImplicitTerminatorOpT>::template Impl<Op>;

public:
  static constexpr bool value = std::is_base_of<ExpectedTrait, Op>::value;
};
/// Ops without an `ImplicitTerminatorOpT` member never have the trait.
template <class Op>
struct hasSingleBlockImplicitTerminator<Op, false> {
  static constexpr bool value = false;
};

//===----------------------------------------------------------------------===//
// Misc Traits

/// Trait for ops whose operands are known to share one shape: all operands
/// are scalars, or vectors/tensors of the same shape. Verification is
/// delegated to `impl::verifySameOperandsShape`.
template <typename ConcreteType>
class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifySameOperandsShape(op);
    return status;
  }
};

/// Trait for ops whose operands and results are known to share one shape:
/// both are scalars, or vectors/tensors of the same shape. Verification is
/// delegated to `impl::verifySameOperandsAndResultShape`.
template <typename ConcreteType>
class SameOperandsAndResultShape
    : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifySameOperandsAndResultShape(op);
    return status;
  }
};

/// Trait for ops whose operands are known to share one element type (or the
/// type itself if it is scalar). Verification is delegated to
/// `impl::verifySameOperandsElementType`.
template <typename ConcreteType>
class SameOperandsElementType
    : public TraitBase<ConcreteType, SameOperandsElementType> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifySameOperandsElementType(op);
    return status;
  }
};

/// Trait for ops whose operands and results are known to share one element
/// type (or the type itself if it is scalar). Verification is delegated to
/// `impl::verifySameOperandsAndResultElementType`.
template <typename ConcreteType>
class SameOperandsAndResultElementType
    : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifySameOperandsAndResultElementType(op);
    return status;
  }
};

/// Trait for ops whose operands and results are known to have the same type.
/// Verification is delegated to `impl::verifySameOperandsAndResultType`.
///
/// Note: this trait subsumes the SameOperandsAndResultShape and
/// SameOperandsAndResultElementType traits.
template <typename ConcreteType>
class SameOperandsAndResultType
    : public TraitBase<ConcreteType, SameOperandsAndResultType> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifySameOperandsAndResultType(op);
    return status;
  }
};

/// Trait verifying that every result of the op has a boolean type, a vector
/// thereof, or a tensor thereof. Verification is delegated to
/// `impl::verifyResultsAreBoolLike`.
template <typename ConcreteType>
class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifyResultsAreBoolLike(op);
    return status;
  }
};

/// Trait verifying that every result of the op has a floating point type, a
/// vector thereof, or a tensor thereof. Verification is delegated to
/// `impl::verifyResultsAreFloatLike`.
template <typename ConcreteType>
class ResultsAreFloatLike
    : public TraitBase<ConcreteType, ResultsAreFloatLike> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifyResultsAreFloatLike(op);
    return status;
  }
};

/// Trait verifying that every result of the op has a signless integer or
/// index type, a vector thereof, or a tensor thereof. Verification is
/// delegated to `impl::verifyResultsAreSignlessIntegerLike`.
template <typename ConcreteType>
class ResultsAreSignlessIntegerLike
    : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifyResultsAreSignlessIntegerLike(op);
    return status;
  }
};

/// This class adds the property that the operation is commutative. It is a
/// pure marker trait: it declares no verifier or folder of its own.
template <typename ConcreteType>
class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {};

/// This class adds the property that the operation is an involution, i.e. a
/// unary to unary operation "f" that satisfies f(f(x)) = x.
template <typename ConcreteType>
class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
public:
  /// Statically require one operand, one result and a preserved type, then
  /// delegate the runtime check to `impl::verifyIsInvolution`.
  static LogicalResult verifyTrait(Operation *op) {
    static_assert(ConcreteType::template hasTrait<OneResult>(),
                  "expected operation to produce one result");
    static_assert(ConcreteType::template hasTrait<OneOperand>(),
                  "expected operation to take one operand");
    static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
                  "expected operation to preserve type");
    // An involution must also be side effect free, but that check is
    // currently under a FIXME and is not actually performed.
    LogicalResult status = impl::verifyIsInvolution(op);
    return status;
  }

  /// Fold hook for this trait; the constant operands are not consulted and
  /// the work is delegated to `impl::foldInvolution`.
  static OpFoldResult foldTrait(Operation *op,
                                ArrayRef<Attribute> /*operands*/) {
    return impl::foldInvolution(op);
  }
};

/// This class adds the property that the operation is idempotent, i.e. a
/// unary to unary operation "f" that satisfies f(f(x)) = f(x), or a binary
/// operation "g" that satisfies g(x, x) = x.
template <typename ConcreteType>
class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
public:
  /// Statically require one result, one or two operands and a preserved
  /// type, then delegate the runtime check to `impl::verifyIsIdempotent`.
  static LogicalResult verifyTrait(Operation *op) {
    static_assert(ConcreteType::template hasTrait<OneResult>(),
                  "expected operation to produce one result");
    static_assert(ConcreteType::template hasTrait<OneOperand>() ||
                      ConcreteType::template hasTrait<NOperands<2>::Impl>(),
                  "expected operation to take one or two operands");
    static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
                  "expected operation to preserve type");
    // An idempotent op must also be side effect free, but that check is
    // currently under a FIXME and is not actually performed.
    LogicalResult status = impl::verifyIsIdempotent(op);
    return status;
  }

  /// Fold hook for this trait; the constant operands are not consulted and
  /// the work is delegated to `impl::foldIdempotent`.
  static OpFoldResult foldTrait(Operation *op,
                                ArrayRef<Attribute> /*operands*/) {
    return impl::foldIdempotent(op);
  }
};

/// Trait verifying that every operand of the op has a float type, a vector
/// thereof, or a tensor thereof. Verification is delegated to
/// `impl::verifyOperandsAreFloatLike`.
template <typename ConcreteType>
class OperandsAreFloatLike
    : public TraitBase<ConcreteType, OperandsAreFloatLike> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifyOperandsAreFloatLike(op);
    return status;
  }
};

/// Trait verifying that every operand of the op has a signless integer or
/// index type, a vector thereof, or a tensor thereof. Verification is
/// delegated to `impl::verifyOperandsAreSignlessIntegerLike`.
template <typename ConcreteType>
class OperandsAreSignlessIntegerLike
    : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifyOperandsAreSignlessIntegerLike(op);
    return status;
  }
};

/// Trait verifying that all operands of the op have the same type.
/// Verification is delegated to `impl::verifySameTypeOperands`.
template <typename ConcreteType>
class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = impl::verifySameTypeOperands(op);
    return status;
  }
};

/// Trait for the sub-set of ops that are known to be constant-like: non-side
/// effecting operations with one result and zero operands that can always be
/// folded to a specific attribute value.
template <typename ConcreteType>
class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
public:
  /// Only compile-time structure is checked (one result, zero operands); the
  /// operation instance itself is not inspected.
  static LogicalResult verifyTrait(Operation * /*op*/) {
    static_assert(ConcreteType::template hasTrait<OneResult>(),
                  "expected operation to produce one result");
    static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
                  "expected operation to take zero operands");
    // TODO: Verify that the operation can always be folded. That requires the
    // attributes of the op to already be verified, so support for verifying
    // traits "after" the operation is needed to enable this use case.
    return success();
  }
};

/// Trait for ops that are known to be isolated from above. The check runs as
/// a region verifier and is delegated to `impl::verifyIsIsolatedFromAbove`.
template <typename ConcreteType>
class IsIsolatedFromAbove
    : public TraitBase<ConcreteType, IsIsolatedFromAbove> {
public:
  static LogicalResult verifyRegionTrait(Operation *op) {
    LogicalResult status = impl::verifyIsIsolatedFromAbove(op);
    return status;
  }
};

/// A trait of region holding operations that define a new scope for
/// polyhedral optimization purposes. Any SSA values of 'index' type that
/// either dominate such an operation or are used at the top-level of such an
/// operation automatically become valid symbols for the polyhedral scope
/// defined by that operation. For more details, see `Traits.md#AffineScope`.
template <typename ConcreteType>
class AffineScope : public TraitBase<ConcreteType, AffineScope> {
public:
  /// Only checks at compile time that the op is not a `ZeroRegions` op; the
  /// operation instance itself is not inspected.
  static LogicalResult verifyTrait(Operation * /*op*/) {
    static_assert(!ConcreteType::template hasTrait<ZeroRegions>(),
                  "expected operation to have one or more regions");
    return success();
  }
};

/// A trait of region holding operations that define a new scope for
/// automatic allocations, i.e., allocations that are freed when control is
/// transferred back from the operation's region. Any operation performing
/// such an allocation (e.g. memref.alloca) has it automatically freed at the
/// closest enclosing operation with this trait.
template <typename ConcreteType>
class AutomaticAllocationScope
    : public TraitBase<ConcreteType, AutomaticAllocationScope> {
public:
  /// Only checks at compile time that the op is not a `ZeroRegions` op; the
  /// operation instance itself is not inspected.
  static LogicalResult verifyTrait(Operation * /*op*/) {
    static_assert(!ConcreteType::template hasTrait<ZeroRegions>(),
                  "expected operation to have one or more regions");
    return success();
  }
};

/// This class provides a verifier for ops that expect their parent to be one
/// of the given `ParentOpTypes`.
template <typename... ParentOpTypes>
struct HasParent {
  template <typename ConcreteType>
  class Impl : public TraitBase<ConcreteType, Impl> {
  public:
    /// Succeed when the parent op exists and is one of `ParentOpTypes`;
    /// otherwise emit an op error listing the accepted parent op names.
    static LogicalResult verifyTrait(Operation *op) {
      Operation *parent = op->getParentOp();
      if (!llvm::isa_and_nonnull<ParentOpTypes...>(parent))
        return op->emitOpError()
               << "expects parent op "
               << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
               << llvm::makeArrayRef({ParentOpTypes::getOperationName()...})
               << "'";

      return success();
    }
  };
};

/// A trait for operations that have an attribute specifying operand segments.
///
/// Some operations have multiple variadic operands whose size relationship is
/// not always known statically. Such ops need a per-op-instance specification
/// that divides the operands into logical groups, or segments. This is
/// modeled by an attribute named `operand_segment_sizes`.
///
/// This trait verifies that the operand segment attribute has the correct
/// type (1D vector) and values (non-negative), etc.
template <typename ConcreteType>
class AttrSizedOperandSegments
    : public TraitBase<ConcreteType, AttrSizedOperandSegments> {
public:
  /// Name of the attribute holding the operand segment sizes.
  static StringRef getOperandSegmentSizeAttr() {
    return "operand_segment_sizes";
  }

  static LogicalResult verifyTrait(Operation *op) {
    StringRef attrName = getOperandSegmentSizeAttr();
    return ::mlir::OpTrait::impl::verifyOperandSizeAttr(op, attrName);
  }
};

/// Counterpart of AttrSizedOperandSegments for results: the segment sizes
/// live in an attribute named `result_segment_sizes`.
template <typename ConcreteType>
class AttrSizedResultSegments
    : public TraitBase<ConcreteType, AttrSizedResultSegments> {
public:
  /// Name of the attribute holding the result segment sizes.
  static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; }

  static LogicalResult verifyTrait(Operation *op) {
    StringRef attrName = getResultSegmentSizeAttr();
    return ::mlir::OpTrait::impl::verifyResultSizeAttr(op, attrName);
  }
};

/// This trait provides a verifier for ops that expect their regions to not
/// have any arguments. Verification is delegated to
/// `impl::verifyNoRegionArguments`.
template <typename ConcreteType>
struct NoRegionArguments : public TraitBase<ConcreteType, NoRegionArguments> {
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = ::mlir::OpTrait::impl::verifyNoRegionArguments(op);
    return status;
  }
};

/// This trait is used to flag operations that consume or produce values of
/// `MemRef` type where those references can be 'normalized'.
/// TODO: Right now, the operands of an operation are either all normalizable,
/// or not. In the future, we may want to allow some of the operands to be
/// normalizable.
template <typename ConcreteType>
struct MemRefsNormalizable
    : public TraitBase<ConcreteType, MemRefsNormalizable> {};

/// This trait tags element-wise ops on vectors or tensors.
///
/// NOTE: Not every op that is "elementwise" in some abstract sense satisfies
/// this trait. In particular, broadcasting behavior is not allowed.
///
/// An `Elementwise` op must satisfy the following properties:
///
/// 1. If any result is a vector/tensor then at least one operand must also be
///    a vector/tensor.
/// 2. If any operand is a vector/tensor then there must be at least one
///    result and all results must be vectors/tensors.
/// 3. All operand and result vector/tensor types must be of the same shape.
///    The shape may be dynamic, in which case the op's behaviour is undefined
///    for non-matching shapes.
/// 4. The operation must be elementwise on its vector/tensor operands and
///    results. When applied to single-element vectors/tensors, the result
///    must be the same per element.
///
/// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new
/// interface `ElementwiseTypeInterface` that describes the container types
/// for which the operation is elementwise.
///
/// Rationale:
/// - 1. and 2. guarantee a well-defined iteration space and exclude the cases
///   of 0 non-scalar operands or 0 non-scalar results, which complicate a
///   generic definition of the iteration space.
/// - 3. guarantees that folding can be done across scalars/vectors/tensors
///   with the same pattern, as otherwise lots of special handling for type
///   mismatches would be needed.
/// - 4. guarantees that no error handling is needed. Higher-level dialects
///   should reify any needed guards or error handling code before lowering to
///   an `Elementwise` op.
template <typename ConcreteType>
struct Elementwise : public TraitBase<ConcreteType, Elementwise> {
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult status = ::mlir::OpTrait::impl::verifyElementwise(op);
    return status;
  }
};

/// This trait tags `Elementwise` operations that can be systematically
/// scalarized. All vector/tensor operands and results are then replaced by
/// scalars of the respective element type. Semantically, this is the
/// operation on a single element of the vector/tensor.
///
/// Rationale:
/// Allows the vector/tensor semantics of elementwise operations to be defined
/// in terms of the same op's behavior on scalars. This provides a
/// constructive procedure for IR transformations to, e.g., create scalar loop
/// bodies from tensor ops.
///
/// Example:
/// ```
/// %tensor_select = "arith.select"(%pred_tensor, %true_val, %false_val)
///                      : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
///                      -> tensor<?xf32>
/// ```
/// can be scalarized to
///
/// ```
/// %scalar_select = "arith.select"(%pred, %true_val_scalar, %false_val_scalar)
///                      : (i1, f32, f32) -> f32
/// ```
template <typename ConcreteType>
struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> {
  /// Only checks at compile time that the op is also `Elementwise`; the
  /// operation instance itself is not inspected.
  static LogicalResult verifyTrait(Operation * /*op*/) {
    static_assert(
        ConcreteType::template hasTrait<Elementwise>(),
        "`Scalarizable` trait is only applicable to `Elementwise` ops.");
    return success();
  }
};

/// This trait tags `Elementwise` operations that can be systematically
/// vectorized. All scalar operands and results are then replaced by vectors
/// with the respective element type. Semantically, this is the operation on
/// multiple elements simultaneously. See also `Tensorizable`.
///
/// Rationale:
/// Provides the reverse of `Scalarizable` which, when chained together,
/// allows reasoning about the relationship between the tensor and vector
/// case. Additionally, it permits reasoning about promoting scalars to
/// vectors via broadcasting (compare the scalar-predicate `arith.select`
/// example in the `Tensorizable` documentation).
template <typename ConcreteType>
struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> {
  /// Only checks at compile time that the op is also `Elementwise`; the
  /// operation instance itself is not inspected.
  static LogicalResult verifyTrait(Operation * /*op*/) {
    static_assert(
        ConcreteType::template hasTrait<Elementwise>(),
        "`Vectorizable` trait is only applicable to `Elementwise` ops.");
    return success();
  }
};

/// This trait tags `Elementwise` operations that can be systematically
/// tensorized. All scalar operands and results are then replaced by tensors
/// with the respective element type. Semantically, this is the operation on
/// multiple elements simultaneously. See also `Vectorizable`.
///
/// Rationale:
/// Provides the reverse of `Scalarizable` which, when chained together,
/// allows reasoning about the relationship between the tensor and vector
/// case. Additionally, it permits reasoning about promoting scalars to
/// tensors via broadcasting in cases like `%scalar_pred` below.
///
/// Examples:
/// ```
/// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32
/// ```
/// can be tensorized to
/// ```
/// %tensor = "arith.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>)
///               -> tensor<?xf32>
/// ```
///
/// ```
/// %scalar_pred = "arith.select"(%pred, %true_val, %false_val)
///                    : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
/// ```
/// can be tensorized to
/// ```
/// %tensor_pred = "arith.select"(%pred, %true_val, %false_val)
///                    : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
///                    -> tensor<?xf32>
/// ```
template <typename ConcreteType>
struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> {
  /// Only checks at compile time that the op is also `Elementwise`; the
  /// operation instance itself is not inspected.
  static LogicalResult verifyTrait(Operation * /*op*/) {
    static_assert(
        ConcreteType::template hasTrait<Elementwise>(),
        "`Tensorizable` trait is only applicable to `Elementwise` ops.");
    return success();
  }
};

/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable`
/// provide an easy way for scalar operations to conveniently generalize their
/// behavior to vectors/tensors, and systematize conversion between these forms.
///
/// NOTE(review): only the declaration lives in this header; presumably this
/// returns true when `op` carries all four of the traits above -- confirm
/// against the definition.
bool hasElementwiseMappableTraits(Operation *op);

} // namespace OpTrait

//===----------------------------------------------------------------------===//
// Internal Trait Utilities
//===----------------------------------------------------------------------===//

namespace op_definition_impl {
//===----------------------------------------------------------------------===//
// Trait Existence

/// Returns true if the given trait ID matches the ID of any of the provided
/// trait types `Traits`. Returns false when `Traits` is an empty pack.
template <template <typename T> class... Traits>
static bool hasTrait(TypeID traitID) {
  // The leading `false` keeps the array non-empty when `Traits` is an empty
  // pack; expanding the pack directly into an array of unknown bound would
  // otherwise declare a zero-length array, which is not valid standard C++.
  const bool matches[] = {false, (TypeID::get<Traits>() == traitID)...};
  for (bool isMatch : matches)
    if (isMatch)
      return true;
  return false;
}

//===----------------------------------------------------------------------===//
// Trait Folding

/// Trait to check if T provides a 'foldTrait' method for single result
/// operations, i.e. one callable as `T::foldTrait(Operation *,
/// ArrayRef<Attribute>)`. The alias is only well-formed when that call
/// expression is valid.
template <typename T, typename... Args>
using has_single_result_fold_trait = decltype(T::foldTrait(
    std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>()));
/// Detector wrapping `has_single_result_fold_trait` with `llvm::is_detected`.
template <typename T>
using detect_has_single_result_fold_trait =
    llvm::is_detected<has_single_result_fold_trait, T>;
/// Trait to check if T provides a general 'foldTrait' method, i.e. one that
/// additionally takes a `SmallVectorImpl<OpFoldResult> &` to append results
/// to.
template <typename T, typename... Args>
using has_fold_trait =
    decltype(T::foldTrait(std::declval<Operation *>(),
                          std::declval<ArrayRef<Attribute>>(),
                          std::declval<SmallVectorImpl<OpFoldResult> &>()));
/// Detector wrapping `has_fold_trait` with `llvm::is_detected`.
template <typename T>
using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>;
/// Trait to check if T provides any `foldTrait` method: selects the general
/// detector when it holds and the single-result detector otherwise, so
/// `::value` is true if either form exists.
/// NOTE: This should use std::disjunction when C++17 is available.
template <typename T>
using detect_has_any_fold_trait =
    std::conditional_t<bool(detect_has_fold_trait<T>::value),
                       detect_has_fold_trait<T>,
                       detect_has_single_result_fold_trait<T>>;

/// Folds a trait whose `foldTrait` function is specialized for operations
/// with a single result. Fails if `results` is already populated or if the
/// trait produces no fold result. On success the folded value is appended to
/// `results`, unless it is the op's own result, in which case `results` is
/// left untouched.
template <typename Trait>
static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value,
                        LogicalResult>
foldTrait(Operation *op, ArrayRef<Attribute> operands,
          SmallVectorImpl<OpFoldResult> &results) {
  assert(op->hasTrait<OpTrait::OneResult>() &&
         "expected trait on non single-result operation to implement the "
         "general `foldTrait` method");
  // A previous trait has already folded and replaced this operation, so this
  // trait must not fold it again.
  if (!results.empty())
    return failure();

  OpFoldResult folded = Trait::foldTrait(op, operands);
  if (!folded)
    return failure();

  // Only record the value when it differs from the op's own result.
  if (folded.template dyn_cast<Value>() != op->getResult(0))
    results.push_back(folded);
  return success();
}
/// Folds a trait that implements the generalized `foldTrait` function, which
/// supports any operation type. Fails without calling the trait when
/// `results` is already populated.
template <typename Trait>
static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult>
foldTrait(Operation *op, ArrayRef<Attribute> operands,
          SmallVectorImpl<OpFoldResult> &results) {
  // A previous trait has already folded and replaced this operation, so this
  // trait must not fold it again.
  if (!results.empty())
    return failure();
  return Trait::foldTrait(op, operands, results);
}
/// Fallback for traits that provide no `foldTrait` method at all: there is
/// nothing to fold, so this always fails.
template <typename Trait>
static inline std::enable_if_t<!detect_has_any_fold_trait<Trait>::value,
                               LogicalResult>
foldTrait(Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) {
  LogicalResult notFolded = failure();
  return notFolded;
}

/// Given a set of traits `Ts`, try to fold the given operation with each of
/// them in order. Every trait is attempted; the result is success if at least
/// one of them folded.
template <typename... Ts>
static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
                                SmallVectorImpl<OpFoldResult> &results) {
  // With an empty trait pack the expansion below references none of the
  // parameters; mark them used, as `verifyRegionTraits` does for `op`.
  (void)op;
  (void)operands;
  (void)results;
  bool anyFolded = false;
  (void)std::initializer_list<int>{
      (anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
  return success(anyFolded);
}

//===----------------------------------------------------------------------===//
// Trait Verification

/// Trait to check if T provides a `verifyTrait` method callable with an
/// `Operation *`.
template <typename T, typename... Args>
using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
/// Detector wrapping `has_verify_trait` with `llvm::is_detected`.
template <typename T>
using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;

/// Trait to check if T provides a `verifyRegionTrait` method callable with an
/// `Operation *`.
template <typename T, typename... Args>
using has_verify_region_trait =
    decltype(T::verifyRegionTrait(std::declval<Operation *>()));
/// Detector wrapping `has_verify_region_trait` with `llvm::is_detected`.
template <typename T>
using detect_has_verify_region_trait =
    llvm::is_detected<has_verify_region_trait, T>;

/// Verify the given trait if it provides a verifier. This overload is
/// selected when `T::verifyTrait(Operation *)` exists and forwards to it.
template <typename T>
std::enable_if_t<detect_has_verify_trait<T>::value, LogicalResult>
verifyTrait(Operation *op) {
  LogicalResult status = T::verifyTrait(op);
  return status;
}
/// Fallback for traits without a `verifyTrait` method: there is nothing to
/// check, so verification trivially succeeds.
template <typename T>
inline std::enable_if_t<!detect_has_verify_trait<T>::value, LogicalResult>
verifyTrait(Operation *) {
  LogicalResult verified = success();
  return verified;
}

/// Given a set of traits, return the result of verifying the given operation.
/// Traits are verified in order and verification stops at the first failure:
/// once a trait has failed, the remaining verifiers are not invoked.
template <typename... Ts>
LogicalResult verifyTraits(Operation *op) {
  // With an empty trait pack the expansion below never references `op`; mark
  // it used, as `verifyRegionTraits` does.
  (void)op;
  LogicalResult result = success();
  (void)std::initializer_list<int>{
      (result = succeeded(result) ? verifyTrait<Ts>(op) : failure(), 0)...};
  return result;
}

/// Verify the given trait if it provides a region verifier. This overload is
/// selected when `T::verifyRegionTrait(Operation *)` exists and forwards to
/// it.
template <typename T>
std::enable_if_t<detect_has_verify_region_trait<T>::value, LogicalResult>
verifyRegionTrait(Operation *op) {
  LogicalResult status = T::verifyRegionTrait(op);
  return status;
}
/// Fallback for traits without a `verifyRegionTrait` method: there is nothing
/// to check, so verification trivially succeeds.
template <typename T>
inline std::enable_if_t<!detect_has_verify_region_trait<T>::value,
                        LogicalResult>
verifyRegionTrait(Operation *) {
  LogicalResult verified = success();
  return verified;
}

/// Given a set of traits, return the result of verifying the regions of the
/// given operation. Traits are verified in order and verification stops at
/// the first failure.
template <typename... Ts>
LogicalResult verifyRegionTraits(Operation *op) {
  // Keeps `op` referenced even when the trait pack is empty.
  (void)op;
  LogicalResult status = success();
  (void)std::initializer_list<int>{
      (status = failed(status) ? failure() : verifyRegionTrait<Ts>(op),
       0)...};
  return status;
}
} // namespace op_definition_impl

//===----------------------------------------------------------------------===//
// Operation Definition classes
//===----------------------------------------------------------------------===//

/// This provides public APIs that all operations should have.  The template
/// argument 'ConcreteType' should be the concrete type by CRTP and the others
/// are base classes by the policy pattern.
template <typename ConcreteType, template <typename T> class... Traits>
class Op : public OpState, public Traits<ConcreteType>... {
public:
  /// Inherit getOperation from `OpState`.
  using OpState::getOperation;
  using OpState::verify;
  using OpState::verifyRegions;

  /// Return if this operation contains the provided trait.
  template <template <typename T> class Trait>
  static constexpr bool hasTrait() {
    return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
  }

  /// Create a deep copy of this operation and return it as `ConcreteType`.
  ConcreteType clone() {
    Operation *copy = getOperation()->clone();
    return cast<ConcreteType>(copy);
  }

  /// Create a partial copy of this operation that does not traverse into the
  /// attached regions. The copy has the same number of regions as the
  /// original, but they are left empty.
  ConcreteType cloneWithoutRegions() {
    Operation *copy = getOperation()->cloneWithoutRegions();
    return cast<ConcreteType>(copy);
  }

  /// Return true if this "op class" can match against the specified operation.
  /// A registered operation matches when its registered TypeID is the TypeID
  /// of `ConcreteType`. An unregistered operation never matches; in builds
  /// without NDEBUG it is a fatal error if its name nevertheless equals this
  /// op's name.
  static bool classof(Operation *op) {
    auto registeredInfo = op->getRegisteredInfo();
    if (registeredInfo)
      return registeredInfo->getTypeID() == TypeID::get<ConcreteType>();
#ifndef NDEBUG
    const bool sameName =
        op->getName().getStringRef() == ConcreteType::getOperationName();
    if (sameName)
      llvm::report_fatal_error(
          "classof on '" + ConcreteType::getOperationName() +
          "' failed due to the operation not being registered");
#endif
    return false;
  }
  /// Provide `classof` support for other OpBase derived classes, such as
  /// Interfaces.
  template <typename T>
  static std::enable_if_t<std::is_base_of<OpState, T>::value, bool>
  classof(const T *op) {
    return classof(const_cast<T *>(op)->getOperation());
  }

  /// Expose the type we are instantiated on to template machinery that may want
  /// to introspect traits on this operation.
  using ConcreteOpType = ConcreteType;

  /// This is a public constructor.  Any op can be initialized to null.
  explicit Op() : OpState(nullptr) {}
  Op(std::nullptr_t) : OpState(nullptr) {}

  /// This is a public constructor to enable access via the llvm::cast family of
  /// methods. This should not be used directly.
  explicit Op(Operation *state) : OpState(state) {}

  /// Methods for supporting PointerLikeTypeTraits.
  const void *getAsOpaquePointer() const {
    return static_cast<const void *>((Operation *)*this);
  }
  static ConcreteOpType getFromOpaquePointer(const void *pointer) {
    return ConcreteOpType(
        reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
  }

  /// Attach the given models as implementations of the corresponding
  /// interfaces for the concrete operation.
  template <typename... Models>
  static void attachInterface(MLIRContext &context) {
    Optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
        ConcreteType::getOperationName(), &context);
    if (!info)
      llvm::report_fatal_error(
          "Attempting to attach an interface to an unregistered operation " +
          ConcreteType::getOperationName() + ".");
    (void)std::initializer_list<int>{(checkInterfaceTarget<Models>(), 0)...};
    info->attachInterface<Models...>();
  }

private:
  /// Trait to check if T provides a 'fold' method for a single result op.
  template <typename T, typename... Args>
  using has_single_result_fold =
      decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
  template <typename T>
  using detect_has_single_result_fold =
      llvm::is_detected<has_single_result_fold, T>;
  /// Trait to check if T provides a general 'fold' method.
  template <typename T, typename... Args>
  using has_fold = decltype(std::declval<T>().fold(
      std::declval<ArrayRef<Attribute>>(),
      std::declval<SmallVectorImpl<OpFoldResult> &>()));
  template <typename T>
  using detect_has_fold = llvm::is_detected<has_fold, T>;
  /// Trait to check if T provides a 'print' method.
  template <typename T, typename... Args>
  using has_print =
      decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
  template <typename T>
  using detect_has_print = llvm::is_detected<has_print, T>;

  /// Trait to check if T provides a 'ConcreteEntity' type alias.
  template <typename T>
  using has_concrete_entity_t = typename T::ConcreteEntity;

  /// A struct-wrapped type alias to T::ConcreteEntity if provided and to
  /// ConcreteType otherwise. This is akin to std::conditional but doesn't fail
  /// on the missing typedef. Useful for checking if the interface is targeting
  /// the right class.
  template <typename T,
            bool = llvm::is_detected<has_concrete_entity_t, T>::value>
  struct InterfaceTargetOrOpT {
    using type = typename T::ConcreteEntity;
  };
  template <typename T>
  struct InterfaceTargetOrOpT<T, false> {
    using type = ConcreteType;
  };

  /// A hook for static assertion that the external interface model T is
  /// targeting the concrete type of this op. The model can also be a fallback
  /// model that works for every op.
  template <typename T>
  static void checkInterfaceTarget() {
    static_assert(std::is_same<typename InterfaceTargetOrOpT<T>::type,
                               ConcreteType>::value,
                  "attaching an interface to the wrong op kind");
  }

  /// Returns an interface map containing the interfaces registered to this
  /// operation.
  static detail::InterfaceMap getInterfaceMap() {
    return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
  }

  /// Return the internal implementations of each of the OperationName
  /// hooks.
  /// Implementation of `FoldHookFn` OperationName hook.
  static OperationName::FoldHookFn getFoldHookFn() {
    return getFoldHookFnImpl<ConcreteType>();
  }
  /// The internal implementation of `getFoldHookFn` above that is invoked if
  /// the operation is single result and defines a `fold` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
                                          Traits<ConcreteOpT>...>::value &&
                              detect_has_single_result_fold<ConcreteOpT>::value,
                          OperationName::FoldHookFn>
  getFoldHookFnImpl() {
    return [](Operation *op, ArrayRef<Attribute> operands,
              SmallVectorImpl<OpFoldResult> &results) {
      return foldSingleResultHook<ConcreteOpT>(op, operands, results);
    };
  }
  /// The internal implementation of `getFoldHookFn` above that is invoked if
  /// the operation is not single result and defines a `fold` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
                                           Traits<ConcreteOpT>...>::value &&
                              detect_has_fold<ConcreteOpT>::value,
                          OperationName::FoldHookFn>
  getFoldHookFnImpl() {
    return [](Operation *op, ArrayRef<Attribute> operands,
              SmallVectorImpl<OpFoldResult> &results) {
      return foldHook<ConcreteOpT>(op, operands, results);
    };
  }
  /// The internal implementation of `getFoldHookFn` above that is invoked if
  /// the operation does not define a `fold` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
                              !detect_has_fold<ConcreteOpT>::value,
                          OperationName::FoldHookFn>
  getFoldHookFnImpl() {
    return [](Operation *op, ArrayRef<Attribute> operands,
              SmallVectorImpl<OpFoldResult> &results) {
      // In this case, we only need to fold the traits of the operation.
      return op_definition_impl::foldTraits<Traits<ConcreteType>...>(
          op, operands, results);
    };
  }
  /// Return the result of folding a single result operation that defines a
  /// `fold` method.
  template <typename ConcreteOpT>
  static LogicalResult
  foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
                       SmallVectorImpl<OpFoldResult> &results) {
    OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);

    // If the fold failed or was in-place, try to fold the traits of the
    // operation.
    if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
      if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
              op, operands, results)))
        return success();
      return success(static_cast<bool>(result));
    }
    results.push_back(result);
    return success();
  }
  /// Return the result of folding an operation that defines a `fold` method.
  template <typename ConcreteOpT>
  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
                                SmallVectorImpl<OpFoldResult> &results) {
    LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);

    // If the fold failed or was in-place, try to fold the traits of the
    // operation.
    if (failed(result) || results.empty()) {
      if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
              op, operands, results)))
        return success();
    }
    return result;
  }

  /// Implementation of `GetCanonicalizationPatternsFn` OperationName hook.
  static OperationName::GetCanonicalizationPatternsFn
  getGetCanonicalizationPatternsFn() {
    return &ConcreteType::getCanonicalizationPatterns;
  }
  /// Implementation of `GetHasTraitFn`
  static OperationName::HasTraitFn getHasTraitFn() {
    return
        [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
  }
  /// Implementation of `ParseAssemblyFn` OperationName hook.
  static OperationName::ParseAssemblyFn getParseAssemblyFn() {
    return &ConcreteType::parse;
  }
  /// Implementation of `PrintAssemblyFn` OperationName hook.
  static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
    return getPrintAssemblyFnImpl<ConcreteType>();
  }
  /// The internal implementation of `getPrintAssemblyFn` that is invoked when
  /// the concrete operation does not define a `print` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
                          OperationName::PrintAssemblyFn>
  getPrintAssemblyFnImpl() {
    return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) {
      return OpState::print(op, printer, defaultDialect);
    };
  }
  /// The internal implementation of `getPrintAssemblyFn` that is invoked when
  /// the concrete operation defines a `print` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
                          OperationName::PrintAssemblyFn>
  getPrintAssemblyFnImpl() {
    return &printAssembly;
  }
  static void printAssembly(Operation *op, OpAsmPrinter &p,
                            StringRef defaultDialect) {
    OpState::printOpName(op, p, defaultDialect);
    return cast<ConcreteType>(op).print(p);
  }
  /// Implementation of `PopulateDefaultAttrsFn` OperationName hook.
  static OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn() {
    return ConcreteType::populateDefaultAttrs;
  }
  /// Implementation of `VerifyInvariantsFn` OperationName hook.
  static LogicalResult verifyInvariants(Operation *op) {
    static_assert(hasNoDataMembers(),
                  "Op class shouldn't define new data members");
    return failure(
        failed(op_definition_impl::verifyTraits<Traits<ConcreteType>...>(op)) ||
        failed(cast<ConcreteType>(op).verify()));
  }
  static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
    return static_cast<LogicalResult (*)(Operation *)>(&verifyInvariants);
  }
  /// Implementation of `VerifyRegionInvariantsFn` OperationName hook.
  static LogicalResult verifyRegionInvariants(Operation *op) {
    static_assert(hasNoDataMembers(),
                  "Op class shouldn't define new data members");
    return failure(
        failed(op_definition_impl::verifyRegionTraits<Traits<ConcreteType>...>(
            op)) ||
        failed(cast<ConcreteType>(op).verifyRegions()));
  }
  static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() {
    return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants);
  }

  static constexpr bool hasNoDataMembers() {
    // Checking that the derived class does not define any member by comparing
    // its size to an ad-hoc EmptyOp.
    class EmptyOp : public Op<EmptyOp, Traits...> {};
    return sizeof(ConcreteType) == sizeof(EmptyOp);
  }

  /// Allow access to internal implementation methods.
  friend RegisteredOperationName;
};

/// Base class of all operation interfaces. The requirements placed on the
/// `Traits` type are described on `detail::Interface`.
template <typename ConcreteType, typename Traits>
class OpInterface
    : public detail::Interface<ConcreteType, Operation *, Traits,
                               Op<ConcreteType>, OpTrait::TraitBase> {
public:
  using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
                                          Op<ConcreteType>, OpTrait::TraitBase>;
  using Base = OpInterface<ConcreteType, Traits>;

  /// Inherit the base class constructor.
  using InterfaceBase::InterfaceBase;

protected:
  /// Returns the impl interface instance for the given operation, or null if
  /// neither the registered operation info nor the dialect provides one.
  static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
    OperationName opName = op->getName();
    Optional<RegisteredOperationName> registeredInfo =
        opName.getRegisteredInfo();

    // Unregistered operation: the dialect, when there is one, is the only
    // place that can implement this interface for the operation.
    if (!registeredInfo) {
      Dialect *owningDialect = opName.getDialect();
      if (!owningDialect)
        return nullptr;
      return owningDialect->getRegisteredInterfaceForOp<ConcreteType>(opName);
    }

    // Registered operation: prefer the raw interface stored in the operation
    // info.
    if (auto *attachedIface = registeredInfo->getInterface<ConcreteType>())
      return attachedIface;

    // Otherwise give the dialect a chance to implement this interface for
    // this operation.
    return registeredInfo->getDialect()
        .getRegisteredInterfaceForOp<ConcreteType>(opName);
  }

  /// Allow access to `getInterfaceFor`.
  friend InterfaceBase;
};

//===----------------------------------------------------------------------===//
// CastOpInterface utilities
//===----------------------------------------------------------------------===//

// These functions are out-of-line implementations of the methods in
// CastOpInterface, which avoids them being template instantiated/duplicated.
namespace impl {
/// Attempt to fold the given cast operation.
///
/// Only declared here; the definition is out of line.
/// NOTE(review): `attrOperands` presumably holds the constant operand
/// attributes of `op` and `foldResults` presumably receives the fold results,
/// as with the `fold` hooks in this file -- confirm against the definition.
LogicalResult foldCastInterfaceOp(Operation *op,
                                  ArrayRef<Attribute> attrOperands,
                                  SmallVectorImpl<OpFoldResult> &foldResults);
/// Attempt to verify the given cast operation.
///
/// Only declared here; the definition is out of line. `areCastCompatible` is a
/// predicate over two type ranges.
/// NOTE(review): the two ranges are presumably the input and result types of
/// the cast, in that order -- confirm against the definition.
LogicalResult verifyCastInterfaceOp(
    Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
} // namespace impl
} // namespace mlir

namespace llvm {

template <typename T>
struct DenseMapInfo<T,
                    std::enable_if_t<std::is_base_of<mlir::OpState, T>::value &&
                                     !mlir::detail::IsInterface<T>::value>> {
  static inline T getEmptyKey() {
    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
    return T::getFromOpaquePointer(pointer);
  }
  static inline T getTombstoneKey() {
    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
    return T::getFromOpaquePointer(pointer);
  }
  static unsigned getHashValue(T val) {
    return hash_value(val.getAsOpaquePointer());
  }
  static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
};

} // namespace llvm

#endif
